import io
import pathlib

import PIL.Image
import matplotlib.pyplot
import torch
from lotka_volterra_dataset import LotkaVolterraDataset
from lotka_volterra_problem import LotkaVolterraProblem


def main():
    """Plot x1(t) for a simplified copy of the dataset's Lotka-Volterra problem.

    Loads the dataset saved at ``./checkpoints/data.pt`` and takes the problem
    it carries. It then builds a new ``LotkaVolterraProblem`` that keeps only
    ``a1`` and ``x1_0`` from that problem and passes ``0`` for every other
    positional argument. ``x1`` is evaluated on 1000 evenly spaced time points
    in ``[0, 3]``, plotted against ``t``, and shown with ``pyplot.show()``.
    """
    # Plain string: the original f-string had no placeholders (F541).
    checkpoint_path = pathlib.Path("./checkpoints/data.pt")
    dataset = LotkaVolterraDataset.load(checkpoint_path)
    source_problem = dataset.problem()

    # Keep only a1 and x1_0 from the loaded problem and zero the remaining
    # positional arguments.
    # NOTE(review): presumably this switches off every term except x1's own
    # growth so its isolated dynamics can be inspected; confirm against the
    # LotkaVolterraProblem constructor signature.
    problem = LotkaVolterraProblem(
        source_problem.a1(), 0, 0, 0,
        source_problem.x1_0(), 0,
    )

    matplotlib.pyplot.figure(figsize=(8, 5), dpi=110)
    t = torch.linspace(0, 3, 1000)
    # Materialize whatever x1_batch yields into a plain list for matplotlib.
    x1 = list(problem.x1_batch(t))
    matplotlib.pyplot.plot(t, x1)

    matplotlib.pyplot.xlabel(r"$t$")
    matplotlib.pyplot.ylabel(r"$x_1$")

    matplotlib.pyplot.show()


# Script entry point: run the plot only when executed directly, not on import.
if __name__ == "__main__":
    main()
